[Rust] ortでonnxモデルを使って推論したりWASMにしたりしてみる
Introduction
以前、BurnというRustの機械学習フレームワークで
ONNXファイルを変換して使うという記事を書きました。
問題なく変換して推論までできましたが、onnxファイルをそのまま使いたいケースもあります。
というわけで、今回はONNX RuntimeのRustラッパー「ort」を使ってみます。
また、wasmpackを使ってWASMにしてChrome Extensionから使ってみます。
[補足] ONNX?
ONNXは、さまざまな機械学習フレームワーク間で使用できる共通フォーマットです。
これを使うことにより、PytorchでトレーニングしたモデルをTensorFlowで使う
みたいなことが容易にできます。
ort?
ortは、ONNXランタイム用のRustバインディングです。
ここで紹介されていますが、ortとONNX Runtimeを併用することで、
さまざまなMLモデル (YOLOv8、BERT、LLaMAなど) を(ほぼ)すべてのハードウェア上で実行でき、
さらに多くのケースでPyTorchよりも高速に実行させることができます。
(機械学習モデルをONNXグラフに変換することで最適化も可能となる)
Environment
- MacBook Pro (14-inch, M3, 2023)
- OS : MacOS 14.5
- Rust : 1.78.0
- wasm-pack 0.12.1
- gh : 2.49.2
Try
では、ortをつかってMNISTの数値認識をやってみましょう。 まずはcargoでプロジェクトを作成します。
% cargo new ort-rs % cd ort-rs
Cargo.tomlは↓のようになってます。
cargo addでortいれるとversion1が入るので注意。
[dependencies] image = "0.25.1" ndarray = "0.15.6" ort = "2.0.0-rc.2" tracing-subscriber = { version = "0.3", features = [ "env-filter", "fmt" ] }
このあたりからmnistのonnxファイルを持ってきます。
あとはテストで確認するための数値画像はこのへんとかで用意。
それらのファイルはプロジェクトroot(ort-rsの下)においておきます。
main.rsはこんな感じです。
ort::Sessionでmnist.onnxをloadして、
ndarrayで画像を変換後にSession::runで推論を実行します。
use ort::{GraphOptimizationLevel, Session, Value}; use std::error::Error; use ndarray::{Array, Ix4, ArrayD}; use std::collections::HashMap; use image::io::Reader as ImageReader; /// メイン関数 fn main() -> Result<(), Box<dyn Error>> { // ログの初期化 tracing_subscriber::fmt::init(); // モデルのセッションを作成 let model = create_session(include_bytes!("../mnist.onnx"))?; // 画像をロードして前処理を行う(8の画像) let img_array = preprocess_image(include_bytes!("../mnist_8.jpg"))?; // 推論を実行して結果を表示 run_inference(&model, img_array)?; Ok(()) } fn create_session(model_data: &[u8]) -> Result<Session, Box<dyn Error>> { let session = Session::builder()? .with_optimization_level(GraphOptimizationLevel::Level3)? .with_intra_threads(4)? .commit_from_memory(model_data)?; Ok(session) } fn preprocess_image(img_data: &[u8]) -> Result<Array<f32, Ix4>, Box<dyn Error>> { // 画像をロード let img = ImageReader::new(std::io::Cursor::new(img_data)) .with_guessed_format()? .decode()? .to_luma8(); // 画像を28x28にリサイズ let img = image::imageops::resize(&img, 28, 28, image::imageops::FilterType::Nearest); // 画像データを正規化してndarrayに変換 let img_array = Array::from_shape_vec( (1, 1, 28, 28), img.iter().map(|&p| p as f32 / 255.0).collect(), )?; Ok(img_array) } fn run_inference(model: &Session, img_array: Array<f32, Ix4>) -> Result<(), Box<dyn Error>> { // 入力データをHashMapに格納 let input_tensor = Value::from_array(img_array)?; let mut inputs = HashMap::new(); inputs.insert("Input3", input_tensor); // 推論を実行 let outputs = model.run(inputs)?; // 結果を表示 for (name, tensor) in outputs.iter() { println!("Output {}: {:?}", name, tensor); let array_view: ndarray::ArrayViewD<f32> = tensor.try_extract_tensor()?; let array: ArrayD<f32> = array_view.to_owned(); println!("Tensor values: {:?}", array); // 最も高い値を持つインデックスを見つける let max_index = array.iter() .enumerate() .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap()) .map(|(index, _)| index) .unwrap(); println!("Predicted digit: {}", max_index); } Ok(()) }
ちなみに、入力パラメータ名って何つかえばいいかわからなかったので、
GenAIに聞いたら↓で調べろといわれた。
import onnx # ONNXモデルをロード model = onnx.load("mnist.onnx") # モデルの入力名を表示 for input in model.graph.input: print(input.name)
実行するとこんな感じです。
一応、画像はちゃんと判定されてますね。
% cargo run Output Plus214_Output_0: Value { ・・・ } Tensor values: [[1.1330107, -5.077395, 8.586635, 4.278495, -7.7753954, 1.1661655, -4.3016477, -10.987418, 14.0324135, 1.9482136]], shape=[1, 10], strides=[10, 1], layout=CFcf (0xf), dynamic ndim=2 Predicted digit: 8
Convert to WASM & Use in Chrome Extension
では次に、mnistプログラムをWASM化して
Chrome Extensionで使ってみます。
onnxのWASM化はまだexperimentalとのことですが、
ortのリポジトリにはWASMのサンプルがあるので、それをつかってみます。
(↑のコードを使おうとしたらうまくいかなかった)
ortのリポジトリをcloneしましょう。
% gh repo clone pykeio/ort
ort/examples/webassemblyにそのまま動くサンプルがあります。
これを少しだけ変えてChrome Extensionで動かしてみます。
Cargo.tomlにcrateを追加します。
serde-wasm-bindgen = "0.6.5"
webassembly/src/lib.rsを下記のように少し修正。
サンプルではortファイルをloadして使ってます。
use image::{ImageBuffer, Luma, Pixel}; use ort::{ArrayExtensions, Session}; use wasm_bindgen::prelude::*; use wasm_bindgen::JsValue; use ndarray::Array4; use serde_wasm_bindgen::to_value; static MODEL_BYTES: &[u8] = include_bytes!("mnist.ort"); #[wasm_bindgen] pub fn classify_image(image_bytes: &[u8]) -> Result<JsValue, JsValue> { let session_builder = match Session::builder() { Ok(builder) => builder, Err(e) => return Err(JsValue::from_str(&format!("Could not create session builder: {:?}", e))), }; let session = match session_builder.commit_from_memory_directly(MODEL_BYTES) { Ok(s) => s, Err(e) => return Err(JsValue::from_str(&format!("Could not read model from memory: {:?}", e))), }; let image_buffer: ImageBuffer<Luma<u8>, Vec<u8>> = match image::load_from_memory(image_bytes) { Ok(img) => img.to_luma8(), Err(e) => return Err(JsValue::from_str(&format!("Could not load image from memory: {:?}", e))), }; let array = Array4::from_shape_fn((1, 1, 28, 28), |(_, _, j, i)| { let pixel = image_buffer.get_pixel(i as u32, j as u32); let channels = pixel.channels(); (channels[0] as f32) / 255.0 }); let inputs = match ort::inputs![array] { Ok(i) => i, Err(e) => return Err(JsValue::from_str(&format!("Error creating inputs: {:?}", e))), }; let outputs = match session.run(inputs) { Ok(o) => o, Err(e) => return Err(JsValue::from_str(&format!("Error during inference: {:?}", e))), }; let probabilities: Vec<f32> = match outputs[0].try_extract_tensor() { Ok(tensor) => tensor.softmax(ndarray::Axis(1)).iter().copied().collect(), Err(e) => return Err(JsValue::from_str(&format!("Error extracting tensor: {:?}", e))), }; // 確率をパーセンテージ形式で小数点第10位までフォーマット let formatted_probabilities: Vec<String> = probabilities.iter().map(|&x| format!("{:.10}%", x * 100.0)).collect(); Ok(to_value(&formatted_probabilities).map_err(|e| JsValue::from_str(&format!("Error serializing output: {:?}", e)))?) } #[cfg(test)] mod tests { use super::*; use wasm_bindgen_test::console_log; use wasm_bindgen_test::wasm_bindgen_test; use serde_wasm_bindgen::from_value; wasm_bindgen_test::wasm_bindgen_test_configure!(run_in_browser); #[wasm_bindgen_test] fn run_test() { use tracing::Level; use tracing_subscriber::fmt; use tracing_subscriber_wasm::MakeConsoleWriter; #[cfg(target_arch = "wasm32")] ort::wasm::initialize(); fmt() .with_ansi(false) .with_max_level(Level::DEBUG) .with_writer(MakeConsoleWriter::default().map_trace_level_to(Level::DEBUG)) .without_time() .init(); std::panic::set_hook(Box::new(console_error_panic_hook::hook)); let image_bytes: &[u8] = include_bytes!("../../../tests/data/mnist_5.jpg"); let result = classify_image(image_bytes).unwrap(); // JsValueをVec<String>に変換 let formatted_probabilities: Vec<String> = from_value(result).unwrap(); console_log!("Probabilities: {:?}", formatted_probabilities); } }
テストしてみる
wasmpackでtestできます。
必要なパッケージをインストールします。
% brew install chromedriver
このままだとテスト実行時にchromedriverが起動しないので、
ctrlを押しながらクリックして警告ダイアログがでないようにします。
そしてテスト実行。動いてます。
% wasm-pack test --headless --chrome [INFO]: 🎯 Checking for the Wasm target... Running headless tests in Chrome on `http://127.0.0.1:56986/` Try find `webdriver.json` for configure browser's capabilities: Not found running 1 test test ortwasm::tests::run_test ... ok test result: ok. 1 passed; 0 failed; 0 ignored; 0 filtered out ・・・
buildコマンドでWASMを作成します。
成功するとpkgディレクトリにwasmやjsができてます。
% wasm-pack build --release --target web % ls -l pkg/ total 13696 -rw-r--r--@ 1 2903 Jun 5 21:10 ortwasm.d.ts -rw-r--r--@ 1 9310 Jun 5 21:10 ortwasm.js -rw-r--r--@ 1 6985877 Jun 5 21:10 ortwasm_bg.wasm -rw-r--r--@ 1 056 Jun 5 21:10 ortwasm_bg.wasm.d.ts -rw-r--r--@ 1 231 Jun 5 21:10 package.json
適当なChrome Extensionを作成
数値を書いてincludeしたWASMで推論するサンプルを作ります。
まずはextensionディレクトリを作成して、
さきほどのpkgディレクトリをコピーしておきます。
あとはExtensionのコードを作成しましょう。
GenAIで「フリーハンドでキャンバスに数値書いてjpgにして、
そのデータをwasmに渡すChrome Extension作って」
と言ったらほとんど生成してくれます。
extension/manifest.jsonは下記。
WASMを実行するためにcontent_security_policyが必要です。
{ "manifest_version": 3, "name": "MNIST WASM Chrome Extension", "version": "1.0", "permissions": ["storage"], "action": { "default_popup": "popup.html", "default_icon": { "16": "images/icon16.png" } }, "background": { "service_worker": "background.js" }, "content_security_policy": { "extension_pages": "script-src 'self' 'wasm-unsafe-eval'" } }
extension/popup.htmlです。
scriptタグの「type="module"」を指定して
Javascriptモジュールを使用します。
<!DOCTYPE html> <html lang="en"> <head> <meta charset="UTF-8"> <meta name="viewport" content="width=device-width, initial-scale=1.0"> <title>Draw and Predict</title> <style> #canvas { border: 1px solid black; } </style> </head> <body> <h1>Draw a Digit</h1> <canvas id="canvas" width="280" height="280"></canvas> <br> <button id="predictButton">Predict</button> <pre id="result"></pre> <script type="module" src="popup.js"></script> </body> </html>
extension/popup.jsは下記のようになっています。(抜粋)
Rustコードで指定したclassify_imageをimportしてます。
import init, { classify_image } from './pkg/ortwasm.js'; document.addEventListener('DOMContentLoaded', () => { init('./pkg/ortwasm_bg.wasm').then(() => { //キャンパス描画処理 ・・・・・・・・・・・ // Predictボタンが押されたときの処理 predictButton.addEventListener('click', () => { //キャンパスをJPEGに変換など const resizedCanvas = document.createElement('canvas'); ・・・・ resizedCanvas.toBlob((blob) => { const reader = new FileReader(); reader.onloadend = () => { const arrayBuffer = reader.result; const uint8Array = new Uint8Array(arrayBuffer); // WASMのMNISTモデルに送信 predictDigit(uint8Array); }; reader.readAsArrayBuffer(blob); }, 'image/jpeg'); // JPEG形式で保存 }); // WASMのMNISTモデルに送信する関数 async function predictDigit(imageData) { try { const result = await classify_image(imageData); document.getElementById('result').textContent = JSON.stringify(result, null, 2); } catch (error) { console.error('Error during prediction:', error); } } }).catch(console.error); });
ちなみに、テスト時に書いたキャンバスをjpgとしてダウンロードしたかったとき、
↓の関数実行したらそのままダウンロードできて便利だった。
// ローカルにJPEG画像を保存する関数 function saveImage(blob) { const url = URL.createObjectURL(blob); const a = document.createElement('a'); a.href = url; a.download = 'draw_image.jpg'; document.body.appendChild(a); a.click(); document.body.removeChild(a); URL.revokeObjectURL(url); }
chrome://extensions/で、「パッケージ化されたいない拡張機能を読み込む」
を押してextensionディレクトリを指定してインストールします。
実行してみると↓みたいな感じです。ボタンを押すと推論結果が表示されてます。
Summary
今回はONNXランタイム用Rustライブラリortをつかってみました。
onnxがそのまま使えるのは便利ですし、WASMで使えるのも良いです。